# Copyright (c) 2012, GPy authors (see AUTHORS.txt).
# Licensed under the BSD 3-clause license (see LICENSE.txt)


import numpy as np
from scipy import integrate
from .kern import Kern
from ...core.parameterization import Param
from ...util.linalg import tdot
from ... import util
from ...util.config import config # for assesing whether to use cython
from paramz.caching import Cache_this
from paramz.transformations import Logexp


class RBF_atomic(Kern):
    """
    This class defines a special GPy covariance function for atomic
    configurations. It is a modification of the squared exponential (RBF)
    covariance function where the distance between configurations C and C'
    is based on the changes of the inter-atomic distances:

    dist(C,C') = sqrt(SUM_ij{[(1/r_ij-1/r_ij')/l_ij]^2}), where r_ij and
    r_ij' are the distances between atoms i and j in configurations C and
    C', respectively, and l_ij is the lengthscale of the corresponding
    atom pair type.

    Cov(C,C') = m^2 * exp(-0.5*dist(C,C')^2), where m is the magnitude of the covariance.

    The input matrices X and X2 are assumed to be ndarrays of shape N_obs x 3*N_mov,
    where each row represents one configuration including the coordinates of the moving atoms:
    x_1,y_1,z_1,x_2,y_2,z_2,...

    The parameter 'conf_info' is a dictionary including necessary information about the configurations:
    conf_info['conf_fro']: coordinates of active frozen atoms (ndarray of shape N_fro x 3)
    conf_info['atomtype_mov']: atomtype indices for moving atoms (ndarray of shape N_mov)
    conf_info['atomtype_fro']: atomtype indices for active frozen atoms (ndarray of shape N_fro)
    Atomtypes must be indexed as 0,1,2,...,n_at-1 (may include also inactive atomtypes).
    conf_info['pairtype']: pairtype indices for pairs of atomtypes (ndarray of shape n_at x n_at)
    conf_info['n_pt']: number of active pairtypes
    Active pairtypes are indexed as 0,1,2,...,n_pt-1. Inactive pairtypes are given index -1.
    
    Modified from stationary.py by:
    Olli-Pekka Koistinen, Aalto University, 2018
    """

    def __init__(self, input_dim, magnitude, lengthscale, conf_info, name='RBF_atomic'):
        """
        Create the kernel and register its magnitude and lengthscale parameters.

        :param input_dim: input dimensionality, forwarded to the Kern constructor
        :param magnitude: initial value of the magnitude parameter (must have size 1)
        :param lengthscale: initial lengthscales; when None,
            np.ones(conf_info['n_pt']) is used instead
        :param conf_info: dictionary with the configuration information
            (the 'n_pt' entry is read here)
        :param name: name of the kernel
        """
        super(RBF_atomic, self).__init__(input_dim, active_dims=None, name=name, useGPU=False)
        self.conf_info = conf_info
        # Without user-supplied lengthscales start from ones, one entry per active pairtype.
        if lengthscale is None:
            initial_lengthscale = np.ones(self.conf_info['n_pt'])
        else:
            initial_lengthscale = np.asarray(lengthscale)
        # Both parameters are created with a Logexp() transformation.
        self.lengthscale = Param('lengthscale', initial_lengthscale, Logexp())
        self.magnitude = Param('magnitude', magnitude, Logexp())
        assert self.magnitude.size==1
        self.link_parameters(self.magnitude, self.lengthscale)

    def to_dict(self):
        """
        Export the kernel as a dictionary.

        The dictionary is seeded from the base-class export helper when one is
        available and then extended with the class identifier, the magnitude,
        the lengthscales and the configuration information of this kernel.

        :returns: dictionary with (at least) the keys "class", "magnitude",
            "lengthscale" and "conf_info"
        """
        # The dictionary used to be written to without ever being created,
        # which raised a NameError on every call.
        # NOTE(review): the base-class helper is assumed to be called
        # _save_to_input_dict (recent GPy) or _to_dict (older GPy) -- confirm
        # against the Kern class this package ships with.
        parent = super(RBF_atomic, self)
        exporter = getattr(parent, "_save_to_input_dict", None)
        if exporter is None:
            exporter = getattr(parent, "_to_dict", None)
        input_dict = exporter() if exporter is not None else {}
        input_dict["class"] = "GPy.kern.RBF_atomic"
        input_dict["magnitude"] =  self.magnitude.values.tolist()
        input_dict["lengthscale"] = self.lengthscale.values.tolist()
        # conf_info is stored as given (it is not converted to plain lists here).
        input_dict["conf_info"] = self.conf_info
        return input_dict
    
    """
    def K_of_r(self, r):
        raise NotImplementedError("implement the covariance function as a fn of r to use this class")

    def dK_dr(self, r):
        raise NotImplementedError("implement derivative of the covariance function wrt r to use this class")

    @Cache_this(limit=3, ignore_args=())
    def dK2_drdr(self, r):
        raise NotImplementedError("implement second derivative of covariance wrt r to use this method")

    @Cache_this(limit=3, ignore_args=())
    def dK2_drdr_diag(self):
        "Second order derivative of K in r_{i,i}. The diagonal entries are always zero, so we do not give it here."
        raise NotImplementedError("implement second derivative of covariance wrt r_diag to use this method")
    """

    def K_of_dist2(self, dist2):
        """
        Covariance as a function of the squared distance.

        :param dist2: squared distance measure
        :returns: magnitude^2 * exp(-0.5 * dist2)
        """
        amplitude = self.magnitude**2
        return amplitude * np.exp(-0.5 * dist2)

    def dK_ddist2(self, dist2):
        """
        First derivative of the covariance with respect to the squared distance.

        :param dist2: squared distance measure
        :returns: -0.5 * magnitude^2 * exp(-0.5 * dist2)
        """
        prefactor = -0.5 * self.magnitude**2
        return prefactor * np.exp(-0.5 * dist2)

    def dK2_ddist2dist2(self, dist2):
        """
        Second derivative of the covariance with respect to the squared distance.

        :param dist2: squared distance measure
        :returns: 0.25 * magnitude^2 * exp(-0.5 * dist2)
        """
        prefactor = 0.25 * self.magnitude**2
        return prefactor * np.exp(-0.5 * dist2)

    def dK3_ddist2dist2dist2(self, dist2):
        """
        Third derivative of the covariance with respect to the squared distance.

        :param dist2: squared distance measure
        :returns: -0.125 * magnitude^2 * exp(-0.5 * dist2)
        """
        prefactor = -0.125 * self.magnitude**2
        return prefactor * np.exp(-0.5 * dist2)

    @Cache_this(limit=3, ignore_args=())
    def K(self, X, X2=None):
        """
        Covariance matrix between the configurations in X and X2.

        :param X: first set of configurations, one configuration per row
        :param X2: second set of configurations (passed on to _dist2_atomic,
            which uses X when it is None)
        :returns: K_of_dist2 evaluated on the squared atomic distances
        """
        squared_distance = self._dist2_atomic(X, X2)
        return self.K_of_dist2(squared_distance)

    @Cache_this(limit=3, ignore_args=())
    def dK_ddist2_via_X(self, X, X2):
        """
        First derivative of the covariance wrt the squared distance,
        evaluated at the squared atomic distances between X and X2.
        """
        squared_distance = self._dist2_atomic(X, X2)
        return self.dK_ddist2(squared_distance)

    @Cache_this(limit=3, ignore_args=())
    def dK2_ddist2dist2_via_X(self, X, X2):
        """
        Second derivative of the covariance wrt the squared distance,
        evaluated at the squared atomic distances between X and X2.
        """
        squared_distance = self._dist2_atomic(X, X2)
        return self.dK2_ddist2dist2(squared_distance)

    @Cache_this(limit=3, ignore_args=())
    def dK3_ddist2dist2dist2_via_X(self, X, X2):
        """
        Third derivative of the covariance wrt the squared distance,
        evaluated at the squared atomic distances between X and X2.
        """
        squared_distance = self._dist2_atomic(X, X2)
        return self.dK3_ddist2dist2dist2(squared_distance)

    @Cache_this(limit=3, ignore_args=())
    def ddist2_dX(self, X, X2, dimX):
        """
        Derivative of the squared atomic distance with respect to one column of X.

        Only the pairs that contain the moving atom owning column dimX
        contribute: its pairs with the other moving atoms and its pairs with
        the active frozen atoms.

        :param X: first set of configurations, one configuration per row
        :param X2: second set of configurations; X is used when None
        :param dimX: column of X to differentiate with respect to
        :returns: array of shape (X.shape[0], X2.shape[0])
        """
        conf_fro = self.conf_info['conf_fro'] # coordinates of active frozen atoms (N_fro x 3)
        atomtype_mov = self.conf_info['atomtype_mov'] # atomtype indices for moving atoms (N_mov)
        atomtype_fro = self.conf_info['atomtype_fro'] # atomtype indices for active frozen atoms (N_fro)
        pairtype = self.conf_info['pairtype'] # pairtype indices for pairs of atomtypes (n_at x n_at)
        if X2 is None:
            X2 = X
        n_rows, n_cols = X.shape[0], X2.shape[0]
        n_moving = atomtype_mov.shape[0]
        n_frozen = atomtype_fro.shape[0]
        inv_l2 = 1.0/self.lengthscale**2
        atom = int(dimX/3) # moving atom that owns column dimX
        axis = dimX-atom*3 # position of the column within the atom's coordinate triple

        def invr_to_moving(conf, other):
            # inverse distance between `atom` and moving atom `other`, per row of conf
            return 1.0/np.sqrt(np.sum((conf[:,(atom*3):(atom*3+3)]-conf[:,(other*3):(other*3+3)])**2,1))

        def invr_to_frozen(conf, k):
            # inverse distance between `atom` and frozen atom `k`, per row of conf
            return 1.0/np.sqrt(np.sum((conf[:,(atom*3):(atom*3+3)]-conf_fro[k,0:3][None,:])**2,1))

        grad = np.zeros([n_rows,n_cols])
        # pairs with the other moving atoms
        if n_moving > 1:
            for other in range(n_moving):
                if other == atom:
                    continue
                invr_1 = invr_to_moving(X, other)[:,None]
                invr_2 = invr_to_moving(X2, other)[None,:]
                weight = inv_l2[pairtype[atomtype_mov[atom],atomtype_mov[other]]]
                offset = (X[:,atom*3+axis]-X[:,other*3+axis])[:,None]
                term = -2.0*weight*(invr_1-invr_2)*(invr_1**3)*offset
                grad = grad + term
        # pairs with the active frozen atoms
        for k in range(n_frozen):
            invr_1 = invr_to_frozen(X, k)[:,None]
            invr_2 = invr_to_frozen(X2, k)[None,:]
            weight = inv_l2[pairtype[atomtype_mov[atom],atomtype_fro[k]]]
            offset = (X[:,atom*3+axis]-conf_fro[k,axis])[:,None]
            term = -2.0*weight*(invr_1-invr_2)*(invr_1**3)*offset
            grad = grad + term
        return grad

    @Cache_this(limit=3, ignore_args=())
    def dK_dX(self, X, X2, dimX):
        """
        Derivative of the covariance matrix with respect to column dimX of X.

        Chain rule: (d dist2 / dX) * (dK / d dist2), elementwise.

        :returns: array of shape (X.shape[0], X2.shape[0])
        """
        dist2_grad = self.ddist2_dX(X, X2, dimX)
        outer = self.dK_ddist2_via_X(X, X2)
        return dist2_grad*outer

    @Cache_this(limit=3, ignore_args=())
    def dK_dX2(self, X, X2, dimX2):
        """
        Derivative of the covariance matrix with respect to column dimX2 of X2.

        :param X: first set of configurations, one configuration per row
        :param X2: second set of configurations; X is used when None
        :param dimX2: column of X2 to differentiate with respect to
        :returns: array of shape (X.shape[0], X2.shape[0])
        """
        # ddist2_dX receives X2 as its first argument below and reads its
        # shape, so None must be resolved here (previously this crashed,
        # although dK_dX accepts X2=None).
        if X2 is None:
            X2 = X
        # Every term of dist2 is a squared difference of the two inverse
        # distances, i.e. symmetric in its arguments: swap the roles of X and
        # X2 and transpose the result.
        dist2_grad = self.ddist2_dX(X2, X, dimX2).T
        outer = self.dK_ddist2_via_X(X, X2)
        return dist2_grad*outer

    @Cache_this(limit=3, ignore_args=())
    def d2dist2_dXdX2(self, X, X2, dimX, dimX2):
        """
        Mixed second derivative of the squared atomic distance with respect to
        column dimX of X and column dimX2 of X2.

        If the two columns belong to different moving atoms, only the pair
        formed by these two atoms contributes. If they belong to the same
        atom, all pairs of that atom (with the other moving atoms and with the
        active frozen atoms) contribute.

        :param X: first set of configurations, one configuration per row
        :param X2: second set of configurations; X is used when None
        :returns: array of shape (X.shape[0], X2.shape[0])
        """
        conf_fro = self.conf_info['conf_fro'] # coordinates of active frozen atoms (N_fro x 3)
        atomtype_mov = self.conf_info['atomtype_mov'] # atomtype indices for moving atoms (N_mov)
        atomtype_fro = self.conf_info['atomtype_fro'] # atomtype indices for active frozen atoms (N_fro)
        pairtype = self.conf_info['pairtype'] # pairtype indices for pairs of atomtypes (n_at x n_at)
        if X2 is None:
            X2 = X
        n_rows, n_cols = X.shape[0], X2.shape[0]
        n_moving = atomtype_mov.shape[0]
        n_frozen = atomtype_fro.shape[0]
        inv_l2 = 1.0/self.lengthscale**2
        atom_1 = int(dimX/3) # moving atom owning column dimX
        axis_1 = dimX-atom_1*3
        atom_2 = int(dimX2/3) # moving atom owning column dimX2
        axis_2 = dimX2-atom_2*3

        def invr_moving(conf, a, b):
            # inverse distance between moving atoms a and b, per row of conf
            return 1.0/np.sqrt(np.sum((conf[:,(a*3):(a*3+3)]-conf[:,(b*3):(b*3+3)])**2,1))

        def invr_frozen(conf, a, k):
            # inverse distance between moving atom a and frozen atom k, per row of conf
            return 1.0/np.sqrt(np.sum((conf[:,(a*3):(a*3+3)]-conf_fro[k,0:3][None,:])**2,1))

        hess = np.zeros([n_rows,n_cols])
        if atom_1 != atom_2:
            # different atoms: single contribution from the pair (atom_1, atom_2)
            invr_1 = invr_moving(X, atom_1, atom_2)[:,None]
            invr_2 = invr_moving(X2, atom_1, atom_2)[None,:]
            weight = inv_l2[pairtype[atomtype_mov[atom_1],atomtype_mov[atom_2]]]
            offset_1 = (X[:,atom_1*3+axis_1]-X[:,atom_2*3+axis_1])[:,None]
            offset_2 = (X2[:,atom_1*3+axis_2]-X2[:,atom_2*3+axis_2])[None,:]
            term = 2.0*weight*(invr_1**3)*(invr_2**3)*offset_1*offset_2
            return hess + term
        # same atom: sum over all of its pairs
        atom = atom_1
        if n_moving > 1:
            for other in range(n_moving):
                if other == atom:
                    continue
                invr_1 = invr_moving(X, atom, other)[:,None]
                invr_2 = invr_moving(X2, atom, other)[None,:]
                weight = inv_l2[pairtype[atomtype_mov[atom],atomtype_mov[other]]]
                offset_1 = (X[:,atom*3+axis_1]-X[:,other*3+axis_1])[:,None]
                offset_2 = (X2[:,atom*3+axis_2]-X2[:,other*3+axis_2])[None,:]
                term = -2.0*weight*(invr_1**3)*(invr_2**3)*offset_1*offset_2
                hess = hess + term
        for k in range(n_frozen):
            invr_1 = invr_frozen(X, atom, k)[:,None]
            invr_2 = invr_frozen(X2, atom, k)[None,:]
            weight = inv_l2[pairtype[atomtype_mov[atom],atomtype_fro[k]]]
            offset_1 = (X[:,atom*3+axis_1]-conf_fro[k,axis_1])[:,None]
            offset_2 = (X2[:,atom*3+axis_2]-conf_fro[k,axis_2])[None,:]
            term = -2.0*weight*(invr_1**3)*(invr_2**3)*offset_1*offset_2
            hess = hess + term
        return hess
    
    @Cache_this(limit=3, ignore_args=())
    def dK2_dXdX2(self, X, X2, dimX, dimX2):
        """
        Mixed second derivative of the covariance matrix with respect to
        column dimX of X and column dimX2 of X2.

        With K = f(dist2): f' * d2dist2/dXdX2 + f'' * ddist2/dX * ddist2/dX2.

        :param X: first set of configurations, one configuration per row
        :param X2: second set of configurations; X is used when None
        :returns: array of shape (X.shape[0], X2.shape[0])
        """
        # ddist2_dX(X2, X, ...) below reads X2.shape, so None must be resolved
        # here (previously this crashed for X2=None).
        if X2 is None:
            X2 = X
        grad_1 = self.ddist2_dX(X, X2, dimX)
        # derivative wrt X2 via the symmetry of dist2: swap arguments, transpose
        grad_2 = self.ddist2_dX(X2, X, dimX2).T
        hess = self.d2dist2_dXdX2(X, X2, dimX, dimX2)
        DK = self.dK_ddist2_via_X(X, X2)
        DK2 = self.dK2_ddist2dist2_via_X(X, X2)
        return DK*hess + DK2*grad_1*grad_2

    @Cache_this(limit=3, ignore_args=())
    def dK_dmagnitude(self, X, X2):
        """
        Derivative of the covariance matrix with respect to the magnitude.

        K is proportional to magnitude^2, hence dK/dmagnitude = 2*K/magnitude.
        """
        covariance = self.K(X, X2)
        return 2*covariance/self.magnitude
    
    @Cache_this(limit=3, ignore_args=())
    def dK2_dmagnitudedX(self, X, X2, dim):
        """
        Derivative of dK_dX(X, X2, dim) with respect to the magnitude,
        computed as 2*dK_dX/magnitude.
        """
        input_grad = self.dK_dX(X, X2, dim)
        return 2*input_grad/self.magnitude
    
    @Cache_this(limit=3, ignore_args=())
    def dK2_dmagnitudedX2(self, X, X2, dim):
        """
        Derivative of dK_dX2(X, X2, dim) with respect to the magnitude,
        computed as 2*dK_dX2/magnitude.
        """
        input_grad = self.dK_dX2(X, X2, dim)
        return 2*input_grad/self.magnitude
    
    @Cache_this(limit=3, ignore_args=())
    def dK3_dmagnitudedXdX2(self, X, X2, dim, dimX2):
        """
        Derivative of dK2_dXdX2(X, X2, dim, dimX2) with respect to the
        magnitude, computed as 2*dK2_dXdX2/magnitude.
        """
        mixed_grad = self.dK2_dXdX2(X, X2, dim, dimX2)
        return 2*mixed_grad/self.magnitude

    @Cache_this(limit=3, ignore_args=())
    def ddist2_dlengthscale(self, X, X2):
        """
        Derivatives of the squared atomic distance with respect to the lengthscales.

        :param X: first set of configurations, one configuration per row
        :param X2: second set of configurations; X is used when None
        :returns: list with conf_info['n_pt'] arrays of shape
            (X.shape[0], X2.shape[0]); entry pt collects the contributions of
            the atom pairs whose pairtype index is pt
        """
        conf_fro = self.conf_info['conf_fro'] # coordinates of active frozen atoms (N_fro x 3)
        atomtype_mov = self.conf_info['atomtype_mov'] # atomtype indices for moving atoms (N_mov)
        atomtype_fro = self.conf_info['atomtype_fro'] # atomtype indices for active frozen atoms (N_fro)
        pairtype = self.conf_info['pairtype'] # pairtype indices for pairs of atomtypes (n_at x n_at)
        n_pt = self.conf_info['n_pt'] # number of active pairtypes
        if X2 is None:
            X2 = X
        n_rows, n_cols = X.shape[0], X2.shape[0]
        n_moving = atomtype_mov.shape[0]
        n_frozen = atomtype_fro.shape[0]
        inv_l3 = 1.0/self.lengthscale**3
        per_type = [np.zeros([n_rows,n_cols]) for _ in range(n_pt)]

        def invr_moving(conf, a, b):
            # inverse distance between moving atoms a and b, per row of conf
            return 1.0/np.sqrt(np.sum((conf[:,(a*3):(a*3+3)]-conf[:,(b*3):(b*3+3)])**2,1))

        def invr_frozen(conf, a, k):
            # inverse distance between moving atom a and frozen atom k, per row of conf
            return 1.0/np.sqrt(np.sum((conf[:,(a*3):(a*3+3)]-conf_fro[k,0:3][None,:])**2,1))

        # pairs of moving atoms (each unordered pair once; empty when n_moving < 2)
        for a in range(n_moving-1):
            for b in range(a+1, n_moving):
                invr_1 = invr_moving(X, a, b)[:,None]
                invr_2 = invr_moving(X2, a, b)[None,:]
                pt = pairtype[atomtype_mov[a],atomtype_mov[b]]
                per_type[pt] = per_type[pt] - 2.0*inv_l3[pt]*(invr_1-invr_2)**2
        # pairs of a moving atom and an active frozen atom
        for a in range(n_moving):
            for k in range(n_frozen):
                invr_1 = invr_frozen(X, a, k)[:,None]
                invr_2 = invr_frozen(X2, a, k)[None,:]
                pt = pairtype[atomtype_mov[a],atomtype_fro[k]]
                per_type[pt] = per_type[pt] - 2.0*inv_l3[pt]*(invr_1-invr_2)**2
        return per_type

    @Cache_this(limit=3, ignore_args=())
    def dK_dlengthscale(self, X, X2):
        """
        Derivatives of the covariance matrix with respect to the lengthscales.

        :returns: list with one array per entry of ddist2_dlengthscale, each
            multiplied elementwise by dK/d dist2
        """
        dist2_grads = self.ddist2_dlengthscale(X, X2)
        DK = self.dK_ddist2_via_X(X, X2)
        return [dist2_grad*DK for dist2_grad in dist2_grads]

    @Cache_this(limit=3, ignore_args=())
    def d2dist2_dlengthscaledX(self, X, X2, dimX):
        """
        Mixed second derivatives of the squared atomic distance with respect
        to the lengthscales and column dimX of X.

        :param X: first set of configurations, one configuration per row
        :param X2: second set of configurations; X is used when None
        :param dimX: column of X to differentiate with respect to
        :returns: list with conf_info['n_pt'] arrays of shape
            (X.shape[0], X2.shape[0]), indexed by pairtype
        """
        conf_fro = self.conf_info['conf_fro'] # coordinates of active frozen atoms (N_fro x 3)
        atomtype_mov = self.conf_info['atomtype_mov'] # atomtype indices for moving atoms (N_mov)
        atomtype_fro = self.conf_info['atomtype_fro'] # atomtype indices for active frozen atoms (N_fro)
        pairtype = self.conf_info['pairtype'] # pairtype indices for pairs of atomtypes (n_at x n_at)
        n_pt = self.conf_info['n_pt'] # number of active pairtypes
        if X2 is None:
            X2 = X
        n_rows, n_cols = X.shape[0], X2.shape[0]
        n_moving = atomtype_mov.shape[0]
        n_frozen = atomtype_fro.shape[0]
        inv_l3 = 1.0/self.lengthscale**3
        atom = int(dimX/3) # moving atom that owns column dimX
        axis = dimX-atom*3 # position of the column within the atom's coordinate triple
        per_type = [np.zeros([n_rows,n_cols]) for _ in range(n_pt)]

        def invr_to_moving(conf, other):
            # inverse distance between `atom` and moving atom `other`, per row of conf
            return 1.0/np.sqrt(np.sum((conf[:,(atom*3):(atom*3+3)]-conf[:,(other*3):(other*3+3)])**2,1))

        def invr_to_frozen(conf, k):
            # inverse distance between `atom` and frozen atom `k`, per row of conf
            return 1.0/np.sqrt(np.sum((conf[:,(atom*3):(atom*3+3)]-conf_fro[k,0:3][None,:])**2,1))

        # pairs with the other moving atoms
        if n_moving > 1:
            for other in range(n_moving):
                if other == atom:
                    continue
                invr_1 = invr_to_moving(X, other)[:,None]
                invr_2 = invr_to_moving(X2, other)[None,:]
                pt = pairtype[atomtype_mov[atom],atomtype_mov[other]]
                offset = (X[:,atom*3+axis]-X[:,other*3+axis])[:,None]
                term = 4.0*inv_l3[pt]*(invr_1-invr_2)*(invr_1**3)*offset
                per_type[pt] = per_type[pt] + term
        # pairs with the active frozen atoms
        for k in range(n_frozen):
            invr_1 = invr_to_frozen(X, k)[:,None]
            invr_2 = invr_to_frozen(X2, k)[None,:]
            pt = pairtype[atomtype_mov[atom],atomtype_fro[k]]
            offset = (X[:,atom*3+axis]-conf_fro[k,axis])[:,None]
            term = 4.0*inv_l3[pt]*(invr_1-invr_2)*(invr_1**3)*offset
            per_type[pt] = per_type[pt] + term
        return per_type

    @Cache_this(limit=3, ignore_args=())
    def dK2_dlengthscaledX(self, X, X2, dimX):
        """
        Mixed second derivatives of the covariance matrix with respect to the
        lengthscales and column dimX of X.

        With K = f(dist2), each entry is
        f' * d2dist2/dl dX + f'' * ddist2/dX * ddist2/dl.

        :returns: list with one array per pairtype
        """
        grad_x = self.ddist2_dX(X, X2, dimX)
        grads_l = self.ddist2_dlengthscale(X, X2)
        grads_lx = self.d2dist2_dlengthscaledX(X, X2, dimX)
        DK = self.dK_ddist2_via_X(X, X2)
        DK2 = self.dK2_ddist2dist2_via_X(X, X2)
        return [DK*grad_lx + DK2*grad_x*grad_l
                for grad_lx, grad_l in zip(grads_lx, grads_l)]
    
    @Cache_this(limit=3, ignore_args=())
    def dK2_dlengthscaledX2(self, X, X2, dimX2):
        """
        Mixed second derivatives of the covariance matrix with respect to the
        lengthscales and column dimX2 of X2.

        With K = f(dist2), each entry is
        f' * d2dist2/dl dX2 + f'' * ddist2/dX2 * ddist2/dl.

        :param X: first set of configurations, one configuration per row
        :param X2: second set of configurations; X is used when None
        :param dimX2: column of X2 to differentiate with respect to
        :returns: list with one array of shape (X.shape[0], X2.shape[0]) per pairtype
        """
        # The swapped calls below pass X2 as the first argument, whose shape is
        # read, so None must be resolved here (previously this crashed).
        if X2 is None:
            X2 = X
        # derivatives wrt X2 via the symmetry of dist2: swap arguments, transpose
        grad_x2 = self.ddist2_dX(X2, X, dimX2).T
        grads_l = self.ddist2_dlengthscale(X, X2)
        grads_lx2_T = self.d2dist2_dlengthscaledX(X2, X, dimX2)
        DK = self.dK_ddist2_via_X(X, X2)
        DK2 = self.dK2_ddist2dist2_via_X(X, X2)
        return [DK*grad_lx2_T.T + DK2*grad_x2*grad_l
                for grad_lx2_T, grad_l in zip(grads_lx2_T, grads_l)]
   
    @Cache_this(limit=3, ignore_args=())
    def d3dist2_dlengthscaledXdX2(self, X, X2, dimX, dimX2):
        """
        Third derivatives of the squared atomic distance with respect to the
        lengthscales, column dimX of X and column dimX2 of X2.

        This is the lengthscale derivative of d2dist2_dXdX2, split by pairtype:
        the factor 1/l^2 of every pair term becomes -2/l^3, and the term is
        stored in the entry of the pair's pairtype.

        :param X: first set of configurations, one configuration per row
        :param X2: second set of configurations; X is used when None
        :param dimX: column of X to differentiate with respect to
        :param dimX2: column of X2 to differentiate with respect to
        :returns: list with conf_info['n_pt'] arrays of shape
            (X.shape[0], X2.shape[0]), indexed by pairtype
        """
        conf_fro = self.conf_info['conf_fro'] # coordinates of active frozen atoms (N_fro x 3)
        atomtype_mov = self.conf_info['atomtype_mov'] # atomtype indices for moving atoms (N_mov)
        atomtype_fro = self.conf_info['atomtype_fro'] # atomtype indices for active frozen atoms (N_fro)
        pairtype = self.conf_info['pairtype'] # pairtype indices for pairs of atomtypes (n_at x n_at)
        n_pt = self.conf_info['n_pt'] # number of active pairtypes
        if X2 is None:
            X2 = X
        n_rows, n_cols = X.shape[0], X2.shape[0]
        n_moving = atomtype_mov.shape[0]
        n_frozen = atomtype_fro.shape[0]
        inv_l = 1.0/self.lengthscale
        atom_1 = int(dimX/3) # moving atom owning column dimX
        axis_1 = dimX-atom_1*3
        atom_2 = int(dimX2/3) # moving atom owning column dimX2
        axis_2 = dimX2-atom_2*3
        per_type = [np.zeros([n_rows,n_cols]) for _ in range(n_pt)]

        def invr_moving(conf, a, b):
            # inverse distance between moving atoms a and b, per row of conf
            return 1.0/np.sqrt(np.sum((conf[:,(a*3):(a*3+3)]-conf[:,(b*3):(b*3+3)])**2,1))

        def invr_frozen(conf, a, k):
            # inverse distance between moving atom a and frozen atom k, per row of conf
            return 1.0/np.sqrt(np.sum((conf[:,(a*3):(a*3+3)]-conf_fro[k,0:3][None,:])**2,1))

        if atom_1 != atom_2:
            # Different atoms: the single pair (atom_1, atom_2) contributes to
            # the entry of its own pairtype. This used to be gated by
            # "pt == dimX", which compared a pairtype index with a coordinate
            # column index and dropped the term whenever the two differed.
            pt = pairtype[atomtype_mov[atom_1],atomtype_mov[atom_2]]
            invr_1 = invr_moving(X, atom_1, atom_2)[:,None]
            invr_2 = invr_moving(X2, atom_1, atom_2)[None,:]
            offset_1 = (X[:,atom_1*3+axis_1]-X[:,atom_2*3+axis_1])[:,None]
            offset_2 = (X2[:,atom_1*3+axis_2]-X2[:,atom_2*3+axis_2])[None,:]
            scale = -4.0*(inv_l[pt]*invr_1*invr_2)**3
            per_type[pt] = per_type[pt] + scale*offset_1*offset_2
            return per_type
        # Same atom: every pair of that atom contributes to its pairtype entry.
        atom = atom_1
        if n_moving > 1:
            for other in range(n_moving):
                if other == atom:
                    continue
                invr_1 = invr_moving(X, atom, other)[:,None]
                invr_2 = invr_moving(X2, atom, other)[None,:]
                pt = pairtype[atomtype_mov[atom],atomtype_mov[other]]
                offset_1 = (X[:,atom*3+axis_1]-X[:,other*3+axis_1])[:,None]
                offset_2 = (X2[:,atom*3+axis_2]-X2[:,other*3+axis_2])[None,:]
                scale = 4.0*(inv_l[pt]*invr_1*invr_2)**3
                per_type[pt] = per_type[pt] + scale*offset_1*offset_2
        for k in range(n_frozen):
            invr_1 = invr_frozen(X, atom, k)[:,None]
            invr_2 = invr_frozen(X2, atom, k)[None,:]
            pt = pairtype[atomtype_mov[atom],atomtype_fro[k]]
            offset_1 = (X[:,atom*3+axis_1]-conf_fro[k,axis_1])[:,None]
            offset_2 = (X2[:,atom*3+axis_2]-conf_fro[k,axis_2])[None,:]
            scale = 4.0*(inv_l[pt]*invr_1*invr_2)**3
            per_type[pt] = per_type[pt] + scale*offset_1*offset_2
        return per_type

    @Cache_this(limit=3, ignore_args=())
    def dK3_dlengthscaledXdX2(self, X, X2, dimX, dimX2):
        """
        Third derivatives of the covariance matrix with respect to the
        lengthscales, column dimX of X and column dimX2 of X2.

        With K = f(d), d = dist2 and subscripts l, 1, 2 denoting derivatives
        wrt a lengthscale, X[:, dimX] and X2[:, dimX2], each entry is
        f' * d_l12 + f'' * (d_1*d_l2 + d_2*d_l1 + d_l*d_12) + f''' * d_1*d_2*d_l.

        :param X: first set of configurations, one configuration per row
        :param X2: second set of configurations; X is used when None
        :returns: list with one array of shape (X.shape[0], X2.shape[0]) per pairtype
        """
        # The swapped calls below pass X2 as the first argument, whose shape is
        # read, so None must be resolved here (previously this crashed).
        if X2 is None:
            X2 = X
        d_1 = self.ddist2_dX(X, X2, dimX)
        # derivatives wrt X2 via the symmetry of dist2: swap arguments, transpose
        d_2 = self.ddist2_dX(X2, X, dimX2).T
        d_12 = self.d2dist2_dXdX2(X, X2, dimX, dimX2)
        d_l = self.ddist2_dlengthscale(X, X2)
        d_l1 = self.d2dist2_dlengthscaledX(X, X2, dimX)
        d_l2_T = self.d2dist2_dlengthscaledX(X2, X, dimX2)
        d_l12 = self.d3dist2_dlengthscaledXdX2(X, X2, dimX, dimX2)
        DK = self.dK_ddist2_via_X(X, X2)
        DK2 = self.dK2_ddist2dist2_via_X(X, X2)
        DK3 = self.dK3_ddist2dist2dist2_via_X(X, X2)
        DK_12pt = []
        for pt in range(len(d_l)):
            first = DK*d_l12[pt]
            second = DK2*(d_1*d_l2_T[pt].T+d_2*d_l1[pt]+d_l[pt]*d_12)
            third = DK3*d_1*d_2*d_l[pt]
            DK_12pt.append(first + second + third)
        return DK_12pt
    
    @Cache_this(limit=3, ignore_args=())
    def _dist2_atomic(self, X, X2=None):
        """
        Squared atomic distance measure between every row of X and every row
        of X2 (or between all pairs of rows of X when X2 is None).

        For each atom pair the squared difference of the inverse inter-atomic
        distances in the two configurations is divided by the squared
        lengthscale of the pair's pairtype; the result is summed over all
        moving-moving pairs and all moving-frozen pairs.

        :returns: array of shape (X.shape[0], X2.shape[0])
        """
        conf_fro = self.conf_info['conf_fro'] # coordinates of active frozen atoms (N_fro x 3)
        atomtype_mov = self.conf_info['atomtype_mov'] # atomtype indices for moving atoms (N_mov)
        atomtype_fro = self.conf_info['atomtype_fro'] # atomtype indices for active frozen atoms (N_fro)
        pairtype = self.conf_info['pairtype'] # pairtype indices for pairs of atomtypes (n_at x n_at)
        if X2 is None:
            X2 = X
        n_rows, n_cols = X.shape[0], X2.shape[0]
        n_moving = atomtype_mov.shape[0]
        n_frozen = atomtype_fro.shape[0]
        inv_l2 = 1.0/self.lengthscale**2

        def invr_moving(conf, a, b):
            # inverse distance between moving atoms a and b, per row of conf
            return 1.0/np.sqrt(np.sum((conf[:,(a*3):(a*3+3)]-conf[:,(b*3):(b*3+3)])**2,1))

        def invr_frozen(conf, a, k):
            # inverse distance between moving atom a and frozen atom k, per row of conf
            return 1.0/np.sqrt(np.sum((conf[:,(a*3):(a*3+3)]-conf_fro[k,0:3][None,:])**2,1))

        total = np.zeros([n_rows,n_cols])
        # pairs of moving atoms (each unordered pair once; empty when n_moving < 2)
        for a in range(n_moving-1):
            for b in range(a+1, n_moving):
                invr_1 = invr_moving(X, a, b)[:,None]
                invr_2 = invr_moving(X2, a, b)[None,:]
                weight = inv_l2[pairtype[atomtype_mov[a],atomtype_mov[b]]]
                total = total + weight*(invr_1-invr_2)**2
        # pairs of a moving atom and an active frozen atom
        for a in range(n_moving):
            for k in range(n_frozen):
                invr_1 = invr_frozen(X, a, k)[:,None]
                invr_2 = invr_frozen(X2, a, k)[None,:]
                weight = inv_l2[pairtype[atomtype_mov[a],atomtype_fro[k]]]
                total = total + weight*(invr_1-invr_2)**2
        return total

    def Kdiag(self, X):
        """
        Diagonal of the covariance matrix K(X, X).

        :returns: array of length X.shape[0] filled with magnitude^2
        """
        diagonal = np.empty(X.shape[0])
        diagonal[:] = self.magnitude**2
        return diagonal

    def reset_gradients(self):
        """Set the gradients of the magnitude and of all lengthscales to zero."""
        n_pt = self.conf_info['n_pt']
        self.magnitude.gradient = 0.
        self.lengthscale.gradient = np.zeros(n_pt)

    def update_gradients_diag(self, dL_dKdiag, X):
        """
        Store the parameter gradients given the derivative of the objective
        with respect to the diagonal of the covariance matrix.

        The diagonal equals magnitude^2 (see Kdiag), so the magnitude gradient
        is 2*sum(dL_dKdiag)*magnitude and the lengthscale gradients are zero.

        :param dL_dKdiag: derivative of the objective wrt the diagonal of K
        :param X: inputs (not used here)

        See also update_gradients_full
        """
        weight_sum = np.sum(dL_dKdiag)
        self.magnitude.gradient = 2*weight_sum*self.magnitude
        self.lengthscale.gradient = np.zeros(self.conf_info['n_pt'])

    def update_gradients_full(self, dL_dK, X, X2=None, reset=True):
        """
        Store the parameter gradients given the derivative of the objective
        with respect to the covariance matrix.

        The results are written to self.magnitude.gradient and
        self.lengthscale.gradient.

        :param dL_dK: derivative of the objective wrt the covariance matrix K(X, X2)
        :param X: first set of configurations
        :param X2: second set of configurations (None is passed on unchanged)
        :param reset: accepted but not used in this method
        """
        covariance = self.K(X, X2)
        self.magnitude.gradient = 2*np.sum(covariance*dL_dK)/self.magnitude
        per_type = self.dK_dlengthscale(X, X2)
        n_pt = self.conf_info['n_pt']
        self.lengthscale.gradient = np.array(
            [np.sum(dL_dK*per_type[pt]) for pt in range(n_pt)])

    def update_gradients_direct(self, dL_dMag, dL_dLen):
        """
        Specially intended for the Grid regression case.
        Store gradients that have been computed a priori.

        :param dL_dMag: value assigned to the magnitude gradient
        :param dL_dLen: value assigned to the lengthscale gradient
        """
        self.magnitude.gradient, self.lengthscale.gradient = dL_dMag, dL_dLen
    
    def dgradients_dX(self, X, X2, dimX):
        """
        Derivatives of the parameter gradients of K with respect to column dimX of X.

        :returns: two-element list [dK2_dmagnitudedX, dK2_dlengthscaledX]
        """
        return [
            self.dK2_dmagnitudedX(X, X2, dimX),
            self.dK2_dlengthscaledX(X, X2, dimX),
        ]

    def dgradients_dX2(self, X, X2, dimX2):
        """
        Derivatives of the parameter gradients of K with respect to column dimX2 of X2.

        :returns: two-element list [dK2_dmagnitudedX2, dK2_dlengthscaledX2]
        """
        return [
            self.dK2_dmagnitudedX2(X, X2, dimX2),
            self.dK2_dlengthscaledX2(X, X2, dimX2),
        ]

    def dgradients2_dXdX2(self, X, X2, dimX, dimX2):
        """
        Derivatives of the parameter gradients of K with respect to column
        dimX of X and column dimX2 of X2.

        :returns: two-element list [dK3_dmagnitudedXdX2, dK3_dlengthscaledXdX2]
        """
        return [
            self.dK3_dmagnitudedXdX2(X, X2, dimX, dimX2),
            self.dK3_dlengthscaledXdX2(X, X2, dimX, dimX2),
        ]

    def input_sensitivity(self, summarize=True):
        """
        Sensitivity measure per active pairtype: magnitude^2 / lengthscale^2.

        :param summarize: accepted but not used in this method
        :returns: array with conf_info['n_pt'] entries
        """
        n_pt = self.conf_info['n_pt']
        scaled_ones = self.magnitude**2*np.ones(n_pt)
        return scaled_ones/self.lengthscale**2

    def get_one_dimensional_kernel(self, dimensions):
        """
        Specially intended for the grid regression case.

        Meant to return the kernel corresponding to a single dimension, whose
        values can then be used for reconstructing the full covariance matrix.
        Not implemented for this kernel: always raises NotImplementedError.
        """
        message = "implement one dimensional variation of kernel"
        raise NotImplementedError(message)



